package com.cml.batisext.core.repository.mybatis.base.sqlprovider;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.cml.batisext.core.bean.base.DatabaseBean;
import com.cml.batisext.core.repository.mybatis.annotations.DbField;
import com.cml.batisext.core.repository.mybatis.annotations.DbTable;
import com.cml.batisext.core.repository.mybatis.base.BaseDao;
import com.cml.batisext.core.repository.mybatis.base.BaseDaoSqlProvider;

/**
 * SQL builder for the generic batch-update method.
 *
 * <p>Given the list of beans found under {@code BaseDao.LIST_PARAM_NAME}, it produces a single
 * statement of the form:
 *
 * <pre>
 * update table set
 *   col1=case id when 1 then v1 when 2 then v2 else col1 end,
 *   ...,
 *   version = version + 1
 * where id in (1, 2) and case id when 1 then (id=1 and version=3) when 2 then (id=2 and version=7) end
 * </pre>
 *
 * <p>Fields whose value is {@code null} on a given bean are not written for that row. The
 * {@code else col} branch keeps the stored value instead of overwriting it with NULL. The
 * {@code id} and {@code version} columns are never assigned from bean values.
 *
 * <p>Side effect: each bean's version property is incremented in memory via
 * {@code setVersion(getVersion() + 1)} while the SQL is built.
 *
 * @author
 */
public class BaseDaoUpdateBatchSqlProvider extends BaseDaoSqlProvider {

    /**
     * Builds the batch UPDATE statement for the beans in {@code params}.
     *
     * @param params mapper parameters; the bean list is read from {@code BaseDao.LIST_PARAM_NAME}
     * @param beanType bean class, which must carry a {@link DbTable} annotation
     * @return the SQL text
     * @throws IllegalArgumentException if the list is missing or empty, a bean has no id, or a
     *     bean has a null version
     * @throws Exception propagated from reflective access to bean properties
     */
    @Override
    public String doBuildSql(Map<String, Object> params, Class<?> beanType) throws Exception {
        String tableName = beanType.getAnnotation(DbTable.class).table();

        // The mapper contract puts a List of beans under this key; the cast cannot be checked.
        @SuppressWarnings("unchecked")
        List<Object> list = (List<Object>) params.get(BaseDao.LIST_PARAM_NAME);
        if (list == null || list.isEmpty()) {
            // An empty list would otherwise produce "where case id   end", which is invalid SQL.
            throw new IllegalArgumentException(
                    "batch update requires a non-empty list parameter: " + BaseDao.LIST_PARAM_NAME);
        }

        List<Field> fields = getAllDbField(beanType);
        // Resolved once instead of once per element.
        Method getVersion = DatabaseBean.class.getMethod("getVersion");
        Method setVersion = DatabaseBean.class.getMethod("setVersion", Integer.class);

        // column -> accumulated " when <pk> then <value>" branches; insertion order keeps the
        // generated SET clause deterministic (field declaration order).
        Map<String, StringBuilder> caseByColumn = new LinkedHashMap<String, StringBuilder>();
        StringBuilder idList = new StringBuilder();
        StringBuilder whereCase = new StringBuilder();

        for (Object bean : list) {
            Long pk = getPKValue(bean);
            if (pk == null) {
                throw new IllegalArgumentException("传入bean中id字段必须有值");
            }

            for (Field f : fields) {
                DbField dbField = f.getAnnotation(DbField.class);
                if (dbField == null) {
                    continue;
                }
                String column = getColumnName(f, dbField);
                // id identifies the row and version is handled separately; never assign them here.
                if ("`id`".equals(column) || "`version`".equals(column)) {
                    continue;
                }
                Class<?> columnType = getColumnType(f, dbField);
                Object value = getColumnValue(bean, f, beanType);
                if (value == null) {
                    continue;
                }
                StringBuilder branches = caseByColumn.get(column);
                if (branches == null) {
                    branches = new StringBuilder();
                    caseByColumn.put(column, branches);
                }
                // NOTE(review): the value is inlined into the SQL; safety depends on
                // setupFieldValue quoting/escaping it — confirm in BaseDaoSqlProvider.
                branches.append(" when ").append(pk).append(" then ")
                        .append(setupFieldValue(columnType, value));
            }

            Object versionValue = getVersion.invoke(bean);
            if (versionValue == null) {
                throw new IllegalArgumentException("version must not be null for bean with id " + pk);
            }
            setVersion.invoke(bean, (Integer) versionValue + 1);

            if (idList.length() > 0) {
                idList.append(", ");
            }
            idList.append(pk);
            // Optimistic-lock check: each row must still carry the version the bean was read with.
            whereCase.append(" when ").append(pk).append(" then (id=").append(pk)
                    .append(" and version=").append(versionValue).append(") ");
        }

        StringBuilder sql = new StringBuilder("update ").append(tableName).append(" set ");
        for (Map.Entry<String, StringBuilder> entry : caseByColumn.entrySet()) {
            String column = entry.getKey();
            // "else column" keeps the current value for rows whose bean left this field null;
            // without it CASE yields NULL and the stored value would be wiped.
            sql.append(column).append("=case id").append(entry.getValue())
                    .append(" else ").append(column).append(" end, ");
        }
        // "id in (...)" is logically implied by the CASE (non-matching ids give NULL), but lets
        // the database use the primary-key index instead of scanning the whole table.
        sql.append(" version = version + 1 where id in (").append(idList)
                .append(") and case id ").append(whereCase).append(" end");
        return sql.toString();
    }
}
